

# ################################################
# Used only by the reparameterized MLE and the unified approach
# ################################################

# Reparameterized outcome model E[Y | A, M, L, C].
#
# The outcome regression is written as
#   E[Y | A, M, L, C] = f(M, L, A, C) - S + w0 + wa * A,
# where f(.) = x' beta_f is linear in the leading length(beta_y) - 2
# coefficients, w0 and wa are the last two entries of `beta_y`, and S is the
# px-weighted sum over rows i of
#   sum_{m, l} f(m, l, A_i, C_i) * p(L = l | A = 0, M = m, C_i)
#                                * p(M = m | A = 1, C_i).
#
# Args:
#   dat:    data to predict for; needs columns A, M, L plus every column named
#           in beta_f, beta_l and beta_m.
#   beta_y: coefficients; all but the last two are named after columns of
#           `dat` and define f, the last two are w0 and wa.
#   beta_l: named logistic coefficients for p(L = 1 | A, M, C).
#   beta_m: named logistic coefficients for p(M = 1 | A, C).
#   px:     weights multiplied elementwise into the per-row terms of S
#           (presumably per-row empirical weights such as 1/n -- confirm
#           against callers).
#
# Returns: an nrow(dat) x 1 matrix of predicted outcomes.
estimate_Y <- function(dat, beta_y, beta_l, beta_m, px) {
  p <- length(beta_y)
  if (p < 2L) {
    stop("`beta_y` must end with the two terms w0 and wa", call. = FALSE)
  }
  # seq_len() rather than 1:(p - 2): with p == 2 the latter evaluates to
  # c(1, 0) and would reuse w0 as an f coefficient; seq_len(0) is empty.
  beta_f <- beta_y[seq_len(p - 2L)]
  w0 <- beta_y[p - 1L]
  wa <- beta_y[p]

  # Column positions are looked up in `dat` and reused on the process_data()
  # copies; this assumes process_data() preserves the column layout of `dat`.
  coef_columns <- function(beta) {
    idx <- match(names(beta), colnames(dat))
    if (anyNA(idx)) {
      stop("coefficient(s) with no matching column in `dat`: ",
           paste(names(beta)[is.na(idx)], collapse = ", "), call. = FALSE)
    }
    idx
  }
  # Linear predictor x[, idx] %*% beta; drop = FALSE keeps a matrix even when
  # only a single column is selected.
  linear_predictor <- function(x, idx, beta) {
    as.matrix(x[, idx, drop = FALSE]) %*% beta
  }

  # Copies of `dat` with (M, L) fixed (observed A), and the settings needed
  # by the nuisance models.
  dat_m0l0 <- process_data(dat, a = dat$A, m = 0, l = 0)
  dat_m1l0 <- process_data(dat, a = dat$A, m = 1, l = 0)
  dat_m0l1 <- process_data(dat, a = dat$A, m = 0, l = 1)
  dat_m1l1 <- process_data(dat, a = dat$A, m = 1, l = 1)
  dat_a0m0 <- process_data(dat, a = 0, m = 0, l = dat$L)
  dat_a0m1 <- process_data(dat, a = 0, m = 1, l = dat$L)
  dat_a1 <- process_data(dat, a = 1, m = dat$M, l = dat$L)

  # p(M | A = 1, C)
  eta_m <- linear_predictor(dat_a1, coef_columns(beta_m), beta_m)
  p_m1a1 <- plogis(eta_m)
  p_m0a1 <- plogis(eta_m, lower.tail = FALSE)

  # p(L | A = 0, M = m, C) for m = 0, 1
  idx_l <- coef_columns(beta_l)
  eta_l_m0 <- linear_predictor(dat_a0m0, idx_l, beta_l)
  eta_l_m1 <- linear_predictor(dat_a0m1, idx_l, beta_l)
  p_l1a0m0 <- plogis(eta_l_m0)
  p_l0a0m0 <- plogis(eta_l_m0, lower.tail = FALSE)
  p_l1a0m1 <- plogis(eta_l_m1)
  p_l0a0m1 <- plogis(eta_l_m1, lower.tail = FALSE)

  # f(m, l, A, C) for the four (m, l) combinations
  idx_f <- coef_columns(beta_f)
  f_m0l0ac <- linear_predictor(dat_m0l0, idx_f, beta_f)
  f_m1l0ac <- linear_predictor(dat_m1l0, idx_f, beta_f)
  f_m0l1ac <- linear_predictor(dat_m0l1, idx_f, beta_f)
  f_m1l1ac <- linear_predictor(dat_m1l1, idx_f, beta_f)

  # S: px-weighted sum of f integrated over p(L | A = 0, M, C) p(M | A = 1, C)
  sum_train <- sum(px * (f_m0l0ac * p_l0a0m0 * p_m0a1 +
                           f_m1l0ac * p_l0a0m1 * p_m1a1 +
                           f_m0l1ac * p_l1a0m0 * p_m0a1 +
                           f_m1l1ac * p_l1a0m1 * p_m1a1))

  # E[Y | A, M, L, C] = f(M, L, A, C) - S + w0 + wa * A on the rows of `dat`
  f_mac <- linear_predictor(dat, idx_f, beta_f)
  f_mac - sum_train + w0 + wa * dat$A
}


# ################################################
# Prediction in general
# ################################################

# G-formula: sum out {M, L}

# Mean squared error of the g-formula prediction of Y.
#
# The outcome model is summed out over (M, L):
#   E*[Y | A, C] = sum_{m, l} E*[Y | A, C, M = m, L = l]
#                    * p*(L = l | M = m, A, C) * p*(M = m | A, C),
# and the result is mean((Y - E*[Y | A, C])^2) over the rows of `dat`.
#
# Args:
#   dat:  data with columns Y, A, M, L and every column named in the
#         coefficient vectors used.
#   beta: list with named coefficient vectors `beta_y` (outcome), `beta_l`
#         (logistic model for L) and `beta_m` (logistic model for M); other
#         elements are not used here.
#   px:   weights forwarded to estimate_Y(); only used when opt$reparam is not
#         FALSE.
#   opt:  list; `opt$reparam == FALSE` uses the linear predictor x' beta_y for
#         E*[Y | A, C, M, L], any other value uses estimate_Y(). Other
#         elements are not used here.
#
# Returns: the MSE as a single number.
compute_mse <- function(dat, beta, px, opt) {
  reparam <- opt$reparam
  # Fail with a clear message instead of if()'s "argument is of length zero"
  # when opt$reparam is missing.
  if (length(reparam) != 1L || is.na(reparam)) {
    stop("`opt$reparam` must be a single non-missing value", call. = FALSE)
  }

  beta_y <- beta$beta_y
  beta_l <- beta$beta_l
  beta_m <- beta$beta_m

  # Column positions are looked up in `dat` and reused on the process_data()
  # copies; this assumes process_data() preserves the column layout of `dat`.
  coef_columns <- function(b) {
    idx <- match(names(b), colnames(dat))
    if (anyNA(idx)) {
      stop("coefficient(s) with no matching column in `dat`: ",
           paste(names(b)[is.na(idx)], collapse = ", "), call. = FALSE)
    }
    idx
  }
  # Linear predictor x[, idx] %*% b; drop = FALSE keeps a matrix even when
  # only a single column is selected.
  linear_predictor <- function(x, idx, b) {
    as.matrix(x[, idx, drop = FALSE]) %*% b
  }

  # Copies of `dat` with (M, L) fixed; A and the other columns as observed.
  dat_m0l0 <- process_data(dat, a = dat$A, m = 0, l = 0)
  dat_m0l1 <- process_data(dat, a = dat$A, m = 0, l = 1)
  dat_m1l0 <- process_data(dat, a = dat$A, m = 1, l = 0)
  dat_m1l1 <- process_data(dat, a = dat$A, m = 1, l = 1)

  # Copies with only M fixed, used by the L model.
  dat_m0 <- process_data(dat, a = dat$A, m = 0, l = dat$L)
  dat_m1 <- process_data(dat, a = dat$A, m = 1, l = dat$L)

  # p*(L | M = m, A, C) for m = 0, 1
  idx_l <- coef_columns(beta_l)
  eta_l_m0 <- linear_predictor(dat_m0, idx_l, beta_l)
  eta_l_m1 <- linear_predictor(dat_m1, idx_l, beta_l)
  p_l1m0 <- plogis(eta_l_m0)
  p_l0m0 <- plogis(eta_l_m0, lower.tail = FALSE)
  p_l1m1 <- plogis(eta_l_m1)
  p_l0m1 <- plogis(eta_l_m1, lower.tail = FALSE)

  # p*(M | A, C) on the observed rows
  eta_m <- linear_predictor(dat, coef_columns(beta_m), beta_m)
  p_m1 <- plogis(eta_m)
  p_m0 <- plogis(eta_m, lower.tail = FALSE)

  # E*[Y | A, C, M = m, L = l], evaluated on one of the (m, l) copies above
  if (reparam == FALSE) {
    idx_y <- coef_columns(beta_y)
    predict_y <- function(x) linear_predictor(x, idx_y, beta_y)
  } else {
    # f(m, l, .) - S + w0 + wa * A, as computed by estimate_Y()
    predict_y <- function(x) {
      estimate_Y(as.data.frame(x), beta_y, beta_l, beta_m, px)
    }
  }

  # E*[Y | A, C]: sum over both M and L of outcome x p*(L | M) x p*(M)
  Y_hat <- predict_y(dat_m0l0) * p_l0m0 * p_m0 +
    predict_y(dat_m0l1) * p_l1m0 * p_m0 +
    predict_y(dat_m1l0) * p_l0m1 * p_m1 +
    predict_y(dat_m1l1) * p_l1m1 * p_m1

  mean((dat$Y - Y_hat)^2)
}

